import numpy as np
import pickle
import networkx as nx
import torch
import random
import os

def compute_all_pair_shortest_paths(G):
    """Return shortest-path lengths between all pairs of nodes in ``G``.

    Args:
        G: Graph accepted by ``networkx.all_pairs_shortest_path_length``.

    Returns:
        A dict keyed by source node whose values are the per-source
        ``{target: length}`` mappings produced by networkx.
    """
    # networkx yields (source, lengths) pairs; materialise them into a dict
    # so the result can be indexed repeatedly and pickled.
    return {
        source: lengths
        for source, lengths in nx.all_pairs_shortest_path_length(G)
    }

def compute_distances_to_pivots(shortest_paths, G, pivots):
    """Build a tensor of distances from every node to every pivot.

    Args:
        shortest_paths: Mapping ``node -> {other_node: distance}``.
        G: Graph-like object whose ``nodes()`` yields the node labels.
        pivots: Iterable of pivot nodes; it is traversed once per node.

    Returns:
        A ``torch.float`` tensor with one row per node (in ``G.nodes()``
        order) and one column per pivot.  A pivot missing from a node's
        distance mapping is recorded as ``inf``.
    """
    unreachable = float('inf')
    rows = [
        [shortest_paths[node].get(pivot, unreachable) for pivot in pivots]
        for node in G.nodes()
    ]
    return torch.tensor(rows, dtype=torch.float)


def select_pivots_with_random_initialization(shortest_paths, G, num_pivots=80):
    """Pick ``num_pivots`` distinct nodes of ``G`` uniformly at random.

    Args:
        shortest_paths: Unused; kept so the signature matches existing callers.
        G: Graph-like object whose ``nodes()`` yields the node labels.
        num_pivots: Number of pivots to draw (sampling without replacement).

    Returns:
        A list of ``num_pivots`` distinct nodes.

    Raises:
        ValueError: If ``num_pivots`` is negative or larger than the number
            of nodes in ``G``.
    """
    all_nodes = list(G.nodes())
    # random.sample raises the same exception type for these cases; checking
    # up front lets the message say which argument was out of range.
    if not 0 <= num_pivots <= len(all_nodes):
        raise ValueError(
            "num_pivots must be between 0 and {} (the number of nodes), "
            "got {}".format(len(all_nodes), num_pivots)
        )
    return random.sample(all_nodes, num_pivots)


def calculate_average_error(shortest_paths, G, pivots2, pivots1):
    """Score a candidate pivot set by an average relative error.

    For every node of ``G`` the largest distance from that node to any pivot
    in ``pivots1`` is used as a bound.  That bound is compared against the
    shortest-path distance of every unordered pair of pivots in ``pivots1``,
    and the relative deviations ``|true - bound| / true`` are averaged over
    all (node, pivot pair) combinations.

    Args:
        shortest_paths: Mapping ``node -> {other_node: distance}``; a missing
            inner key is treated as an unreachable pair (infinite distance).
        G: Graph-like object whose ``nodes()`` yields the node labels.
        pivots2: Unused; kept so existing callers keep working.
        pivots1: Sequence of candidate pivot nodes to score.

    Returns:
        The mean relative error.  ``0`` is returned when there is nothing to
        compare: no nodes, fewer than two pivots, or no pivot pair with a
        finite, non-zero distance.  The result is infinite when a comparable
        pivot pair exists and some node cannot reach one of the pivots.
    """
    unreachable = float('inf')
    nodes = list(G.nodes())
    pivots = list(pivots1)
    if not nodes or len(pivots) < 2:
        return 0

    # Pivot-to-pivot distances do not depend on the node being scored, so
    # collect them once.  The pivots themselves are looked up (not their
    # positions in the list), and pairs with an infinite or zero distance are
    # skipped because a relative error is undefined for them.
    pair_dists = []
    for i, first in enumerate(pivots):
        dists_from_first = shortest_paths[first]
        for second in pivots[i + 1:]:
            dist = dists_from_first.get(second, unreachable)
            if 0 < dist < unreachable:
                pair_dists.append(dist)
    if not pair_dists:
        return 0
    pair_dists = np.array(pair_dists, dtype=float)

    total_error = 0.0
    for node in nodes:
        # Look distances up by node label so arbitrary (non 0..n-1) labels
        # work; the bound is the farthest pivot as seen from this node.
        dists_from_node = shortest_paths[node]
        bound = max(dists_from_node.get(p, unreachable) for p in pivots)
        total_error += np.sum(np.abs(pair_dists - bound) / pair_dists)

    return float(total_error / (len(nodes) * len(pair_dists)))


def select_pivots_with_evolution(shortest_paths, G, num_pivots=80,
                                 num_generations=80, verbose=True):
    """Refine a random pivot set by repeated single-pivot replacement.

    Starting from a random pivot set, each generation tries, for every
    position ``i``, the set obtained by dropping ``pivots[i]`` and appending
    one randomly chosen non-pivot node.  Each candidate set is scored with
    ``calculate_average_error``; the lowest-scoring candidate seen so far
    (across all generations) becomes the pivot set for the next generation.

    Args:
        shortest_paths: Mapping ``node -> {other_node: distance}``.
        G: Graph-like object whose ``nodes()`` yields the node labels.
        num_pivots: Size of the pivot set.
        num_generations: Number of refinement rounds (previously fixed at 80).
        verbose: When true (the default), print each candidate set before its
            replacement node is appended, followed by its score.

    Returns:
        The list of pivots held after the last generation.
    """
    pivots = select_pivots_with_random_initialization(shortest_paths, G, num_pivots)
    all_nodes = set(G.nodes())

    best_error = float('inf')
    best_pivots = pivots

    for _ in range(num_generations):
        # `pivots` only changes between generations, so the pool of possible
        # replacements is the same for every position within one generation.
        candidates = list(all_nodes - set(pivots))
        if not candidates:
            # Every node is already a pivot; there is nothing to swap in.
            break

        for i in range(len(pivots)):
            new_pivots = pivots[:i] + pivots[i + 1:]
            if verbose:
                print(new_pivots)

            new_pivots.append(random.choice(candidates))

            avg_error = calculate_average_error(shortest_paths, G, pivots, new_pivots)
            if verbose:
                print(avg_error)
            if avg_error < best_error:
                best_error = avg_error
                best_pivots = new_pivots

        pivots = best_pivots

    return pivots


def main():
    """Run the pivot-embedding pipeline on the Cora dataset files.

    Loads the pickled graph, loads or computes (and caches) the all-pairs
    shortest-path lengths, selects pivots, and writes both the pivot list and
    the node-to-pivot distance tensor as pickle files under ``dataset/Cora``.
    """
    # NOTE: unpickling can execute arbitrary code; only run this against
    # dataset/cache files from a trusted source.
    with open('dataset/Cora/cora.pkl', 'rb') as graph_file:
        G = pickle.load(graph_file)

    shortest_paths_file = 'dataset/Cora/shortest_paths.pkl'
    if not os.path.exists(shortest_paths_file):
        # No cache yet: compute the distances once and store them for reuse.
        shortest_paths = compute_all_pair_shortest_paths(G)
        with open(shortest_paths_file, 'wb') as cache_file:
            pickle.dump(shortest_paths, cache_file)
    else:
        with open(shortest_paths_file, 'rb') as cache_file:
            shortest_paths = pickle.load(cache_file)

    pivots = select_pivots_with_evolution(shortest_paths, G, 80)
    with open('dataset/Cora/pivot_nodes.pkl', 'wb') as pivots_file:
        pickle.dump(pivots, pivots_file)

    node_embeddings = compute_distances_to_pivots(shortest_paths, G, pivots)
    with open('dataset/Cora/node_embeddings.pkl', 'wb') as embeddings_file:
        pickle.dump(node_embeddings, embeddings_file)


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
